# Copyright (C) 2022-2025, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

from collections.abc import Callable
from typing import cast

import torch

from .methods.core import _CAM


class ClassificationMetric:
    r"""Implements Average Drop and Increase in Confidence from ["Grad-CAM++: Improved Visual Explanations for Deep
    Convolutional Networks."](https://arxiv.org/pdf/1710.11063.pdf).

    The raw aggregated metric is computed as follows:

    $$
    \forall N, H, W \in \mathbb{N}, \forall X \in \mathbb{R}^{N \times 3 \times H \times W},
    \forall m \in \mathcal{M}, \forall c \in \mathcal{C}, \\
    AvgDrop_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N f_{m, c}(X_i) \\
    IncrConf_{m, c}(X) = \frac{1}{N} \sum\limits_{i=1}^N g_{m, c}(X_i)
    $$

    where $\mathcal{C}$ is the set of class activation generators,
    $\mathcal{M}$ is the set of classification models,
    with the function $f_{m, c}$ defined as:

    $$
    \forall x \in \mathbb{R}^{3 \times H \times W},
    f_{m, c}(x) = \frac{\max(0, m(x) - m(E_{m, c}(x) \cdot x))}{m(x)}
    $$

    where $E_{m, c}(x)$ is the class activation map of $m$ for input $x$ with method $c$,
    resized to (H, W),

    and with the function $g_{m, c}$ defined as:

    $$
    \forall x \in \mathbb{R}^{3 \times H \times W},\quad
    g_{m, c}(x) =
    \begin{cases}
        1 & \text{if } m(x) < m(E_{m, c}(x) \cdot x) \\
        0 & \text{otherwise}
    \end{cases}
    $$

    Samples whose class activation map contains NaN values are left out of both averages and are
    counted separately in the `nan_count` attribute.

    Args:
        cam_extractor: CAM extractor; its `model` attribute is the classifier being evaluated
        logits_fn: optional function applied to the model output before the class scores are read
            (e.g. a softmax). When omitted, the raw model output is used.

    Example:
        ```python
        from functools import partial
        import torch
        from torchcam.metrics import ClassificationMetric
        metric = ClassificationMetric(cam_extractor, partial(torch.softmax, dim=-1))
        metric.update(input_tensor)
        metric.summary()
        ```
    """

    def __init__(
        self,
        cam_extractor: _CAM,
        logits_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
    ) -> None:
        self.cam_extractor = cam_extractor
        self.logits_fn = logits_fn
        # Start from a clean accumulator state
        self.reset()

    def _get_probs(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Run the extractor's model on `input_tensor` and apply `logits_fn` to its output when one was given."""
        logits = self.cam_extractor.model(input_tensor)
        return cast(torch.Tensor, logits if self.logits_fn is None else self.logits_fn(logits))

    @staticmethod
    def _select_scores(
        scores: torch.Tensor,
        class_idx: int | None,
        preds: torch.Tensor | None,
    ) -> torch.Tensor:
        """Reduce per-class scores to one score per sample.

        Args:
            scores: scores indexed as (sample, class)
            class_idx: fixed class column, read only when `preds` is None
            preds: per-sample class indices, or None when a fixed class is used

        Returns:
            a 1D tensor holding the selected score of each sample
        """
        if preds is None:
            return scores[:, class_idx]
        return scores.gather(1, preds.unsqueeze(1)).squeeze(1)

    def update(
        self,
        input_tensor: torch.Tensor,
        class_idx: int | None = None,
    ) -> None:
        """Update the state of the metric with new predictions.

        The model is put in eval mode, a CAM is extracted for each sample, and the model is run a second
        time on the input multiplied by its resized CAM. Samples whose CAM contains NaN values are skipped
        and only added to `nan_count`. The accumulators are left untouched if an exception is raised.

        Args:
            input_tensor: preprocessed input tensor for the model
            class_idx: class index to focus on (default: index of the top predicted class for each sample).
                Any value that is not an `int` is handled like `None`.
        """
        self.cam_extractor.model.eval()
        probs = self._get_probs(input_tensor)
        # `preds` stays None when a fixed integer class is requested. Every later step keys off `preds`, so
        # the original scores, the NaN filtering and the masked scores always agree on the class being read.
        preds: torch.Tensor | None = None
        if isinstance(class_idx, int):
            cams = self.cam_extractor(class_idx, probs)
        else:
            # Take the top preds for the cam
            preds = probs.argmax(dim=-1)
            cams = self.cam_extractor(preds.tolist(), probs)
        cam = self.cam_extractor.fuse_cams(cams)
        # The extractor has consumed the scores: drop the autograd graph before the metric arithmetic
        probs = self._select_scores(probs, class_idx, preds).detach()

        # Hooks are switched off for the second forward pass and always restored, even on failure
        # (presumably so that pass does not disturb what the extractor recorded — confirm against _CAM).
        self.cam_extractor.disable_hooks()
        try:
            # Safeguard: skip samples whose CAM holds at least one NaN
            discard = torch.isnan(cam).reshape(input_tensor.shape[0], -1).any(dim=-1)
            keep = ~discard
            cam, probs, input_tensor = cam[keep], probs[keep], input_tensor[keep]
            if preds is not None:
                preds = preds[keep]

            drop_sum, increase_sum = 0.0, 0
            # Nothing left to score when every sample was discarded: skip the forward pass on an empty batch
            if cam.shape[0] > 0:
                # Resize the CAM to the spatial size of the input
                cam = torch.nn.functional.interpolate(cam.unsqueeze(1), input_tensor.shape[-2:], mode="bilinear")
                # Create the explanation map & get the new probs
                with torch.inference_mode():
                    masked_probs = self._get_probs(cam * input_tensor)
                masked_probs = self._select_scores(masked_probs, class_idx, preds)
                # Drop (the epsilon avoids a division by zero)
                drop_sum = torch.relu(probs - masked_probs).div(probs + 1e-7).sum().item()
                # Increase
                increase_sum = (probs < masked_probs).sum().item()
        finally:
            self.cam_extractor.enable_hooks()

        self.drop += drop_sum
        self.increase += increase_sum
        self.total += cam.shape[0]
        self.nan_count += int(discard.sum().item())

    def summary(self) -> dict[str, float]:
        """Computes the aggregated metrics.

        Both values are averaged over the samples that were actually scored, i.e. excluding the ones
        discarded because of NaN values in their CAM.

        Returns:
            a dictionary with the average drop and the increase in confidence

        Raises:
            AssertionError: if the metric has not been updated
        """
        if self.total == 0:
            raise AssertionError("you need to update the metric before getting the summary")

        return {
            "avg_drop": self.drop / self.total,
            "conf_increase": self.increase / self.total,
        }

    def reset(self) -> None:
        """Reset the state of the metric."""
        # Running sum of the per-sample relative drops
        self.drop = 0.0
        # Number of samples whose score increased once masked by the CAM
        self.increase = 0.0
        # Number of samples scored so far
        self.total = 0
        # Number of samples skipped because their CAM contained NaN values
        self.nan_count = 0
